# Script which deterministically generates certificates given a definitions file.
import argparse
import datetime
import hashlib
import ipaddress
from pathlib import PurePath
from typing import Any, Dict

import asn1crypto.core as asn1
import cryptography.hazmat.primitives.serialization.pkcs12 as pkcs12
import yaml
from cryptography import x509
from cryptography.hazmat._oid import _OID_NAMES
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.x509.extensions import UnrecognizedExtension
from cryptography.x509.oid import ObjectIdentifier
from ecdsa import SigningKey

# Reverse lookup table: attribute name -> OID object.
NAME_TO_OID = {name: oid for oid, name in _OID_NAMES.items()}
# Register the short-form aliases that we explicitly support next to their long names.
NAME_TO_OID.update(
    {
        alias: NAME_TO_OID[long_name]
        for alias, long_name in (
            ("C", "countryName"),
            ("ST", "stateOrProvinceName"),
            ("O", "organizationName"),
            ("OU", "organizationalUnitName"),
            ("L", "localityName"),
            ("SN", "surname"),
            ("CN", "commonName"),
        )
    }
)

# Path to the file specifying the config.
CONFIGFILE = None

# Config parsed as YAML. Starts out empty and is replaced by setup_global_state().
# (Annotated rather than assigned the typing alias, so that CONFIG.get() is always usable.)
CONFIG: Dict[str, Any] = {}

# <= 825 in order to abide by https://support.apple.com/en-us/HT210176.
MAX_VALIDITY_PERIOD_DAYS = 824

# Datetime to specify as the start time for all certs.
# TODO SERVER-101469 Make this a command-line argument and add it as a Bazel input so that Bazel
# knows when to rerun certificate generation.
DEFAULT_START_TIME = datetime.datetime(datetime.datetime.now().year, 1, 1)
# Allocate serial numbers sequentially; this is the last-used serial.
LAST_SERIAL_NUMBER = 999
# Cache the private key objects for static/*key.pem.
LOADED_KEYS = {}
# Base directory where outputs go.
OUTPUT_PATH = None
# Base path to keys that are used during generation.
STATIC_PATH = None

# Cache of (certificate, private key) tuples, keyed by certificate name.
LOADED_CERT_AND_KEYS = {}


def get_next_serial():
    """Advance the module-wide serial counter and return the new value."""
    global LAST_SERIAL_NUMBER
    # Serials 0..999 are reserved for explicitly configured serial numbers, so with the
    # counter starting at 999 the first generated serial is 1000.
    serial = LAST_SERIAL_NUMBER + 1
    LAST_SERIAL_NUMBER = serial
    return serial


def get_key(cert):
    """Return the private key for a certificate definition, loading it from disk at most once.

    The key file name comes from the definition's "keyfile" entry (or the global default)
    and is resolved relative to STATIC_PATH. Loaded keys are cached in LOADED_KEYS by
    file name.
    """
    keyfile = idx(cert, "keyfile")
    if keyfile is None:
        raise ValueError("All certificates require a keyfile")

    try:
        return LOADED_KEYS[keyfile]
    except KeyError:
        pass  # Not cached yet; load it below.

    password = cert.get("passphrase")
    if password is not None:
        password = bytes(password, "ascii")

    key_path = str(STATIC_PATH / keyfile)
    with open(key_path, "rb") as key_stream:
        pem = key_stream.read()

    loaded = serialization.load_pem_private_key(pem, password=password)
    LOADED_KEYS[keyfile] = loaded
    return loaded


def glbl(key, default=None):
    """Look up `key` in the config's "global" section, returning `default` when absent."""
    global_section = CONFIG.get("global", {})
    return global_section.get(key, default)


def idx(cert, key, default=None):
    """Look up `key` on the cert definition, falling back to the global section.

    Note that any falsy per-certificate value (None, 0, "", False, empty containers) is
    treated as "not set" and triggers the global fallback.
    """
    own_value = cert.get(key, None)
    if own_value:
        return own_value
    return glbl(key, default)


def make_filename(cert):
    """Return the output path (as a string) for a certificate definition."""
    output_file = OUTPUT_PATH / cert["name"]
    return str(output_file)


def find_certificate_definition(name):
    """Return the first certificate definition in the config whose name matches, or None."""
    matching = (definition for definition in CONFIG["certs"] if definition["name"] == name)
    return next(matching, None)


def get_header_comment(cert):
    """Build the comment header placed at the top of a generated file.

    Returns an empty string when the definition sets "include_header" to a falsy value.
    Otherwise the header names the config file and certificate, followed by the
    definition's description with every line prefixed by "# ".
    """
    if not cert.get("include_header", True):
        return ""

    description = cert.get("description", "").replace("\n", "\n# ")
    pieces = [
        "# Autogenerated file, do not edit.\n",
        "# Generate using python -m x509.mkcert --config ",
        CONFIGFILE,
        " ",
        cert["name"],
        "\n#\n",
        "# ",
        description,
        "\n",
    ]
    return "".join(pieces)


def get_cert_and_key(cert_name):
    """Return the (certificate, private key) pair for `cert_name`, loading it if needed.

    Names that match a definition in the config are read from the generated output file,
    using the definition's passphrase (if any) to decrypt the key. Any other name is
    treated as a filesystem path to an externally sourced PEM whose key is expected to be
    unencrypted. Results are cached in LOADED_CERT_AND_KEYS.
    """
    try:
        return LOADED_CERT_AND_KEYS[cert_name]  # Cache hit, nothing to load.
    except KeyError:
        pass

    definition = find_certificate_definition(cert_name)
    if definition:
        pem_path = make_filename(definition)
        password = definition.get("passphrase", None)
        if password:
            password = password.encode("utf-8")
    else:
        # Externally sourced certificate: try by path, hopefully unencrypted.
        pem_path = cert_name
        password = None

    with open(pem_path, "rb") as pem_stream:
        pem = pem_stream.read()

    # The same PEM blob is expected to hold both the certificate and its private key.
    loaded = (
        x509.load_pem_x509_certificate(pem),
        serialization.load_pem_private_key(pem, password=password),
    )
    LOADED_CERT_AND_KEYS[cert_name] = loaded
    return loaded


def get_validity_period(cert):
    """Return the (start, end) datetimes of the certificate's validity window.

    "not_before" and "not_after" are offsets in seconds relative to DEFAULT_START_TIME.
    The start offset defaults to 0 and the end offset defaults to the start offset plus
    MAX_VALIDITY_PERIOD_DAYS.
    """
    max_lifetime_secs = MAX_VALIDITY_PERIOD_DAYS * 24 * 60 * 60
    not_before_secs = int(idx(cert, "not_before", 0))
    not_after_secs = int(idx(cert, "not_after", not_before_secs + max_lifetime_secs))

    return (
        DEFAULT_START_TIME + datetime.timedelta(seconds=not_before_secs),
        DEFAULT_START_TIME + datetime.timedelta(seconds=not_after_secs),
    )


def get_oid(cn_or_oid):
    """Given a string containing an OID or a common name, return the corresponding OID object.

    Known names (including the short aliases registered in NAME_TO_OID) are looked up
    directly; anything else is parsed as a dotted OID string.

    Raises:
        ValueError: if the string is neither a known name nor a parseable OID. The
            underlying parse error is chained as the cause.
    """
    if cn_or_oid in NAME_TO_OID:
        return NAME_TO_OID[cn_or_oid]
    try:
        return ObjectIdentifier(cn_or_oid)
    except Exception as err:
        # Catch Exception rather than using a bare except so that KeyboardInterrupt and
        # SystemExit still propagate.
        raise ValueError(f"Name attribute {cn_or_oid} not recognized") from err


def set_subject(builder, cert, set_issuer=False):
    """Set the subject on the certificate builder according to the certificate definition.

    The "Subject" entry may be either:
      * a mapping of attribute name/OID -> value, which is layered over the globally
        defined Subject unless "explicit_subject" is true, or
      * a list of such mappings, each of which becomes one (possibly multivalued) RDN;
        this form requires "explicit_subject" to be true.
    An empty/missing Subject is only accepted when "explicit_subject" is true, in which
    case an empty name is used.

    If set_issuer is true the issuer is set to the same name (self-signed case).

    Returns the updated builder.

    Raises:
        ValueError: if the Subject is missing, has an unsupported type, or is a list
            without "explicit_subject" being set.
    """
    subject = cert.get("Subject")
    explicit = cert.get("explicit_subject", False)

    if not subject:
        if explicit:
            # An empty subject was explicitly requested.
            if set_issuer:
                builder = builder.issuer_name(x509.Name([]))
            return builder.subject_name(x509.Name([]))
        raise ValueError(cert["name"] + " requires a Subject")

    if isinstance(subject, dict):
        attr_dict = {}
        if not explicit:
            # Start from the globally defined subject RDNs.
            for key, val in glbl("Subject", {}).items():
                attr_dict[get_oid(key)] = val
        # Normal case: the certificate's own RDNs override the globally defined ones.
        for key, val in subject.items():
            attr_dict[get_oid(key)] = val
        name = x509.Name([x509.NameAttribute(oid, val) for oid, val in attr_dict.items()])
    elif isinstance(subject, list):
        # Multivalued RDN case. Use .get() so that a missing key produces this message
        # rather than a KeyError.
        if not explicit:
            raise ValueError("explicit_subject must be set to true when using multivalued RDNs")
        rdns = []
        for rdn_def in subject:
            attrs = [x509.NameAttribute(get_oid(key), val) for key, val in rdn_def.items()]
            rdns.append(x509.RelativeDistinguishedName(attrs))
        name = x509.Name(rdns)
    else:
        raise ValueError(cert["name"] + " Subject must be a mapping or a list of mappings")

    if set_issuer:  # When issuer = self, set the issuer as well
        builder = builder.issuer_name(name)
    return builder.subject_name(name)


def set_validity(builder, cert):
    """Apply the definition's validity window to the certificate builder and return it."""
    not_before, not_after = get_validity_period(cert)
    return builder.not_valid_before(not_before).not_valid_after(not_after)


def to_der_varint(val):
    """Encode a non-negative int as DER definite-length octets.

    Values below 0x80 use the single-byte short form. Larger values use the long form:
    one byte of 0x80 | <number of length bytes>, followed by the value in big-endian
    order. At most 8 length bytes (64 bits) are supported.

    Returns:
        bytes: the encoded length.

    Raises:
        ValueError: if val is negative or does not fit in 64 bits.
    """
    if val < 0:
        raise ValueError("Negative values not permitted in DER payload")

    if val < 0x80:
        return bytes([val])

    num_octets = (val.bit_length() + 7) // 8
    if num_octets > 8:
        raise ValueError("Length is too large to represent in 64bits")

    return bytes([0x80 | num_octets]) + val.to_bytes(num_octets, "big")


def to_der_utf8_string(val):
    """Encode `val` (converted with str()) as a tag 0x0c UTF8String: tag, length, payload."""
    payload = str(val).encode("utf-8")
    header = b"\x0c" + to_der_varint(len(payload))
    return header + payload


def to_der_sequence_pair(name, value):
    """Encode a (name, value) pair as a tag 0x30 SEQUENCE of two UTF8Strings."""
    # Simplified sequence: always exactly two strings, a key followed by a value.
    content = to_der_utf8_string(name) + to_der_utf8_string(value)
    header = b"\x30" + to_der_varint(len(content))
    return header + content


class ExtensionParser:
    """Collection of methods to convert extension definitions into cryptography extension objects.

    Every parser takes the extension's definition value `v` from the config, plus keyword
    context supplied by set_extensions() (public_key, issuer_public_key, issuer_ski);
    parsers ignore the keywords they do not need. The `parsers` table at the bottom maps
    config extension names to these functions.
    """

    @staticmethod
    def basic_constraints(v, **_):
        """Build BasicConstraints from a mapping with optional "CA" and "pathlen" keys."""
        return x509.BasicConstraints(ca=v.get("CA", False), path_length=v.get("pathlen"))

    @staticmethod
    def key_usage(v, **_):
        """Build KeyUsage from an iterable of OpenSSL-style usage names.

        Entries that are not recognized usage names (such as the "critical" marker that
        set_extensions() looks for) are ignored.
        """
        to_param_name = {
            "digitalSignature": "digital_signature",
            "nonRepudiation": "content_commitment",
            "keyEncipherment": "key_encipherment",
            "dataEncipherment": "data_encipherment",
            "keyAgreement": "key_agreement",
            "keyCertSign": "key_cert_sign",
            "cRLSign": "crl_sign",
            "encipherOnly": "encipher_only",
            "decipherOnly": "decipher_only",
        }
        params = {name: False for name in to_param_name.values()}
        for usage in v:
            if usage in to_param_name:
                params[to_param_name[usage]] = True
        return x509.KeyUsage(**params)

    @staticmethod
    def ext_usage_name_to_oid(name):
        """Map an extended key usage name to its OID under 1.3.6.1.5.5.7.3.

        Raises:
            ValueError: if the name is not one of the supported identifiers.
        """
        ext_usage_name_map = {
            "serverAuth": 1,
            "clientAuth": 2,
            "codeSigning": 3,
            "emailProtection": 4,
            "timeStamping": 8,
            "OCSPSigning": 9,
        }
        if name not in ext_usage_name_map:
            raise ValueError(f'Unknown extended key usage identifier: "{name}"')
        return ObjectIdentifier("1.3.6.1.5.5.7.3." + str(ext_usage_name_map[name]))

    @staticmethod
    def extended_key_usage(v, **_):
        """Build ExtendedKeyUsage from an iterable of usage names.

        The "critical" marker is skipped here because set_extensions() derives the
        criticality flag from its presence in the list; it is not a usage identifier.
        """
        return x509.ExtendedKeyUsage(
            [ExtensionParser.ext_usage_name_to_oid(name) for name in v if name != "critical"]
        )

    @staticmethod
    def subject_alt_name(v, **_):
        """Build SubjectAlternativeName from a mapping with "DNS" and/or "IP" entries.

        Each entry may be a single value or a list of values. A "critical" key is skipped
        (set_extensions() handles it).

        Raises:
            ValueError: for any other key.
        """
        names = []
        for key, val in v.items():
            if key == "critical":
                continue
            elif key == "DNS":
                if not isinstance(val, list):
                    val = [val]
                for name in val:
                    names.append(x509.DNSName(name))
            elif key == "IP":
                if not isinstance(val, list):
                    val = [val]
                for ip in val:
                    names.append(x509.IPAddress(ipaddress.ip_address(ip)))
            else:
                raise ValueError(f'Unknown subject alt name type: "{key}"')
        return x509.SubjectAlternativeName(names)

    @staticmethod
    def subject_key_identifier(v, public_key, **_):
        """Build a SubjectKeyIdentifier derived from the certificate's public key.

        Raises:
            ValueError: if the definition is anything other than "hash".
        """
        if v != "hash":
            raise ValueError("Only the 'hash' value is accepted for subjectKeyIdentifier")
        return x509.SubjectKeyIdentifier.from_public_key(public_key)

    @staticmethod
    def mongo_roles(v, **_):
        """Encode a list of {"role": ..., "db": ...} mappings as the mongoRoles extension.

        The payload is a tag 0x31 SET containing one SEQUENCE of two UTF8Strings per role.

        Raises:
            ValueError: if any entry is not exactly a role/db pair.
        """
        oid = ObjectIdentifier("1.3.6.1.4.1.34601.2.1.1")
        pair = b""
        for role in v:
            if (len(role) != 2) or ("role" not in role) or ("db" not in role):
                raise ValueError("mongoRoles must consist of a series of role/db pairs")
            pair = pair + to_der_sequence_pair(role["role"], role["db"])

        val = b"\x31" + to_der_varint(len(pair)) + pair

        return UnrecognizedExtension(oid, val)

    @staticmethod
    def authority_key_identifier(v, issuer_public_key, issuer_ski, **_):
        """Build an AuthorityKeyIdentifier.

        "issuer" derives the identifier from the issuer's public key; "keyid" copies it
        from the issuer's SubjectKeyIdentifier (`issuer_ski`).

        Raises:
            ValueError: for any other definition value.
        """
        if v not in ["keyid", "issuer"]:
            raise ValueError(
                "Only the 'keyid' or 'issuer' values are accepted for authorityKeyIdentifier"
            )

        if v == "issuer":
            return x509.AuthorityKeyIdentifier.from_issuer_public_key(issuer_public_key)
        else:
            return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(issuer_ski)

    @staticmethod
    def mongo_cluster_membership(v, **_):
        """Encode a symbolic name to a mongodbClusterMembership extension."""
        oid = ObjectIdentifier("1.3.6.1.4.1.34601.2.1.2")
        val = to_der_utf8_string(v)
        return UnrecognizedExtension(oid, val)

    @staticmethod
    def authority_information_access(v, **_):
        """Build AuthorityInformationAccess from one or more {"method", "location"} entries.

        Raises:
            ValueError: if any entry's method is not "OCSP" (the only supported method).
        """
        if not isinstance(v, list):
            v = [v]
        if not all(entry["method"] == "OCSP" for entry in v):
            raise ValueError("Only the 'OCSP' method is accepted for authorityInfoAccess")
        return x509.AuthorityInformationAccess(
            [
                x509.AccessDescription(
                    x509.oid.AuthorityInformationAccessOID.OCSP,
                    x509.UniformResourceIdentifier(entry["location"]),
                )
                for entry in v
            ]
        )

    @staticmethod
    def must_staple(v, **_):
        """Build the TLS feature (OCSP must-staple) extension.

        Raises:
            ValueError: if the definition value is falsy.
        """
        if not v:
            raise ValueError("If set, mustStaple must be true")
        oid = ObjectIdentifier("1.3.6.1.5.5.7.1.24")
        # DER for SEQUENCE { INTEGER 5 }.
        val = b"\x30\x03\x02\x01\x05"
        return UnrecognizedExtension(oid, val)

    @staticmethod
    def ns_comment(v, **_):
        """Encode an ASCII comment as the Netscape comment extension (tag 0x16 string)."""
        oid = ObjectIdentifier("2.16.840.1.113730.1.13")
        payload = bytes(v, "ascii")
        # Derive the length prefix from the payload rather than hard-coding 0x1d, which is
        # only correct for comments of exactly 29 bytes.
        val = b"\x16" + to_der_varint(len(payload)) + payload
        return UnrecognizedExtension(oid, val)

    # Config extension name -> parser. The values are the staticmethod objects defined
    # above, which are directly callable on Python 3.10 and newer.
    parsers = {
        "basicConstraints": basic_constraints,
        "keyUsage": key_usage,
        "extendedKeyUsage": extended_key_usage,
        "subjectAltName": subject_alt_name,
        "subjectKeyIdentifier": subject_key_identifier,
        "mongoRoles": mongo_roles,
        "authorityKeyIdentifier": authority_key_identifier,
        "mongoClusterMembership": mongo_cluster_membership,
        "authorityInfoAccess": authority_information_access,
        "mustStaple": must_staple,
        "nsComment": ns_comment,
    }


def set_extensions(builder, cert, **kwargs):
    """Add every extension listed in the definition's "extensions" mapping to the builder.

    Each extension name is dispatched to its parser in ExtensionParser.parsers, with
    `kwargs` forwarded as parser context. Criticality is taken from the definition value:
    a list is critical when it contains "critical", a mapping when its "critical" entry
    is set, and plain strings/booleans are never critical.

    Raises:
        ValueError: for an unsupported extension name or an unsupported value type.
    """
    for ext_name, ext_def in cert.get("extensions", {}).items():
        parse = ExtensionParser.parsers.get(ext_name)
        if parse is None:
            raise ValueError(f'Extension "{ext_name}" is not handled yet')
        extension = parse(ext_def, **kwargs)

        if isinstance(ext_def, list):
            is_critical = "critical" in ext_def
        elif isinstance(ext_def, dict):
            is_critical = ext_def.get("critical", False)
        elif isinstance(ext_def, (str, bool)):
            is_critical = False
        else:
            raise ValueError(f"Could not parse extension: {ext_name} -> {ext_def}")

        builder = builder.add_extension(extension, critical=is_critical)
    return builder


def get_issuer_cert_and_key(cert, key):
    """Return the (issuer certificate, issuer key) pair for a certificate definition.

    For a self-signed definition the certificate slot holds the sentinel string "self"
    and the key is the certificate's own `key`.
    """
    issuer_name = cert.get("Issuer")
    if issuer_name != "self":
        # Signed by a CA: load that CA's certificate and key.
        return get_cert_and_key(issuer_name)
    return "self", key


class SignedCertificateSequence(asn1.Sequence):
    """Python representation of the ASN1 structure of a signed certificate.

    Used by sign_ecdsa_deterministic() to take a DER-encoded certificate apart, re-sign
    its content, and reassemble it with a replacement signature.
    """

    # Field order matters: it mirrors the order of the three elements in the DER encoding.
    _fields = [
        # The portion of the certificate that is covered by the signature.
        ("cert_content", asn1.Sequence),
        # The signature algorithm identifier.
        ("algo_type", asn1.Sequence),
        # The signature value itself.
        ("signature", asn1.BitString),
    ]


def to_bits(bytestr):
    """Expand a byte sequence into a tuple of 0/1 ints, most significant bit first."""
    return tuple((octet >> shift) & 1 for octet in bytestr for shift in range(7, -1, -1))


def sign_ecdsa_deterministic(key, cert):
    """Re-sign a signed certificate with the given ECDSA key in a deterministic fashion.

    Args:
        key: the cryptography EC private key to sign with.
        cert: an already-signed x509.Certificate whose signature is to be replaced.

    Returns:
        A newly loaded x509.Certificate with the same content and a deterministic
        ECDSA/SHA-256 signature.
    """
    # Function-scope import so the module's import block stays untouched.
    from ecdsa.util import sigencode_der

    ecdsa_pkey = SigningKey.from_pem(
        key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
    )
    # Get bytes of our signed certificate as DER and load them.
    all_bytes = cert.public_bytes(encoding=serialization.Encoding.DER)
    seq = SignedCertificateSequence.load(all_bytes)
    # Get just the certificate content and sign it.
    cert_bytes = seq["cert_content"].dump()
    # Let the library DER-encode (r, s). Concatenating the two raw 32-byte halves behind
    # fixed headers is only valid when neither value has its high bit set or a leading
    # zero byte; a DER INTEGER is signed and minimally encoded, so r and s vary in length.
    der_sig = ecdsa_pkey.sign_deterministic(
        cert_bytes, hashfunc=hashlib.sha256, sigencode=sigencode_der
    )
    # Set this as the signature, then dump the new certificate.
    seq["signature"] = to_bits(der_sig)
    signed_bytes = seq.dump()
    # Load the new certificate.
    return x509.load_der_x509_certificate(signed_bytes)


def write_cert_as_pkcs12(cert, key, cert_obj, issuer_obj):
    """Makes a new copy of the cert/key pair using PKCS#12 encoding.

    Args:
        cert: the certificate definition; its "pkcs12" mapping supplies the required
            "passphrase" and an optional output "name" (defaults to the cert's name).
        key: private key object to store.
        cert_obj: the generated x509.Certificate.
        issuer_obj: the issuer's x509.Certificate, or the sentinel "self" (or None) for
            a self-signed certificate, in which case no CA certificates are bundled.

    Raises:
        ValueError: if no PKCS#12 passphrase is configured.
    """
    pkcs12_opts = cert.get("pkcs12") or {}
    if not pkcs12_opts.get("passphrase"):
        raise ValueError("PKCS#12 requires a passphrase")

    # process_normal_cert() represents a self-signed issuer with the string "self", which
    # is not a certificate and must not be handed to the serializer as a CA.
    if issuer_obj is None or issuer_obj == "self":
        ca_chain = None
    else:
        ca_chain = [issuer_obj]

    fname = pkcs12_opts.get("name", cert["name"])
    serialized = pkcs12.serialize_key_and_certificates(
        fname.encode("ascii"),
        key,
        cert_obj,
        cas=ca_chain,
        encryption_algorithm=serialization.BestAvailableEncryption(
            pkcs12_opts["passphrase"].encode("ascii")
        ),
    )
    with open(OUTPUT_PATH / fname, "wb") as f:
        f.write(serialized)


def process_normal_cert(cert):
    """Deterministically generate the certificate file for a definition that has a subject.

    Builds and signs the certificate, then writes header + certificate PEM + key PEM to
    the definition's output path. Depending on the definition it also writes separate
    .crt/.key files ("split_cert_and_key") and a PKCS#12 bundle ("pkcs12"). The result is
    recorded in LOADED_CERT_AND_KEYS so later certificates can use it as an issuer.

    Raises:
        ValueError: if "split_cert_and_key" is requested for a name not ending in ".pem"
            (plus anything raised by the helpers for invalid definitions).
    """
    key = get_key(cert)
    issuer_cert, issuer_key = get_issuer_cert_and_key(cert, key)
    self_signed = issuer_cert == "self"

    # Get SKI of issuer if it exists; we need it for the AuthorityKeyIdentifier extension.
    # Both branches produce a SubjectKeyIdentifier value (or None).
    if self_signed:
        my_ski = cert.get("extensions", {}).get("subjectKeyIdentifier")
        if my_ski is None:
            issuer_ski = None
        else:
            issuer_ski = ExtensionParser.subject_key_identifier(my_ski, key.public_key())
    else:
        try:
            # get_extension_for_class() returns the Extension wrapper; unwrap it so this
            # matches the type produced by the self-signed branch above.
            issuer_ski = issuer_cert.extensions.get_extension_for_class(
                x509.SubjectKeyIdentifier
            ).value
        except x509.ExtensionNotFound:
            issuer_ski = None

    # Set all fields of the certificate.
    builder = x509.CertificateBuilder()
    builder = builder.public_key(key.public_key())
    serial = cert.get("serial")
    if serial is None:
        serial = get_next_serial()
    else:
        serial = int(serial)
    builder = builder.serial_number(serial)
    builder = set_subject(builder, cert, set_issuer=self_signed)
    if not self_signed:
        builder = builder.issuer_name(issuer_cert.subject)
    builder = set_validity(builder, cert)
    builder = set_extensions(
        builder,
        cert,
        public_key=key.public_key(),
        issuer_public_key=issuer_key.public_key(),
        issuer_ski=issuer_ski,
    )

    if isinstance(key, ec.EllipticCurvePrivateKey):
        # For EC, we need to compute a deterministic signature ourselves. While newer
        # versions of OpenSSL support deterministic signing with ECDSA, some of the
        # platforms we run tests on use old versions, so we unfortunately cannot use
        # this feature.
        bad_sig_obj = builder.sign(key, hashes.SHA256())
        cert_obj = sign_ecdsa_deterministic(key, bad_sig_obj)
    else:
        cert_obj = builder.sign(key, hashes.SHA256())

    header = get_header_comment(cert)
    cert_path = make_filename(cert)
    cert_pem = cert_obj.public_bytes(serialization.Encoding.PEM).decode("ascii")
    # Read the key PEM once; it is copied verbatim into each output that needs it.
    with open(str(STATIC_PATH / idx(cert, "keyfile")), "r") as keyf:
        key_pem = keyf.read()

    # Write header + certificate PEM + key PEM to the output file.
    with open(cert_path, "wt") as f:
        f.write(header + cert_pem)
        f.write(key_pem)

    LOADED_CERT_AND_KEYS[cert["name"]] = (cert_obj, key)

    if cert.get("split_cert_and_key", False):
        # Write just the certificate to <path>.crt, and just the key to <path>.key
        if not cert_path.endswith(".pem"):
            raise ValueError("split_cert_and_key requires a certificate name ending in .pem")
        base_path = cert_path[: -len(".pem")]

        with open(base_path + ".crt", "wt") as f:
            f.write(header + cert_pem)

        with open(base_path + ".key", "wt") as f:
            f.write(header + key_pem)

    if cert.get("pkcs12", None) is not None:
        write_cert_as_pkcs12(cert, key, cert_obj, issuer_cert)


def process_cert(cert):
    """Produce every output file for one certificate definition.

    A definition with a Subject (or an explicitly empty one) is generated as a normal
    certificate. A definition with only "append_cert" starts as a file containing just
    the header. In both cases each certificate named in "append_cert" is then appended
    to the output file.

    Raises:
        ValueError: if the definition has neither a Subject nor "append_cert".
    """
    out_path = make_filename(cert)
    print("Processing certificate: " + cert["name"] + ", writing to: " + out_path)

    extras = cert.get("append_cert", [])
    if isinstance(extras, str):
        extras = [extras]  # A single name is shorthand for a one-element list.

    subject = cert.get("Subject")
    wants_empty_subject = cert.get("explicit_subject", False) and not subject
    if subject or wants_empty_subject:
        process_normal_cert(cert)
    elif extras:
        # Pure composing certificate. Start with a basic preamble.
        with open(out_path, "wt") as out:
            out.write(get_header_comment(cert) + "\n")
    else:
        raise ValueError(
            "Certificate definitions must have at least one of 'Subject' and/or 'append_cert'"
        )

    for extra_name in extras:
        extra_cert = get_cert_and_key(extra_name)[0]
        if cert.get("include_header", True):
            banner = "# Certificate from " + extra_name + "\n"
        else:
            banner = ""
        extra_pem = extra_cert.public_bytes(serialization.Encoding.PEM).decode("ascii")
        with open(out_path, "at") as out:
            out.write(banner + extra_pem)


# Supported digest names and the hash algorithm objects used to compute them.
DIGEST_NAME_TO_HASH = {"sha256": hashes.SHA256(), "sha1": hashes.SHA1()}


def write_digest(filename, item_type, digest_type):
    """Write the fingerprint of a PEM certificate/CRL to <filename>.digest.<digest_type>.

    `item_type` is "cert" or "crl" and selects how the file is parsed; `digest_type` must
    be a key of DIGEST_NAME_TO_HASH. The digest is written as uppercase hex.
    """
    assert item_type in {"cert", "crl"}
    assert digest_type in DIGEST_NAME_TO_HASH

    with open(filename, "rb") as source:
        pem = source.read()

    load = x509.load_pem_x509_certificate if item_type == "cert" else x509.load_pem_x509_crl
    fingerprint = load(pem).fingerprint(DIGEST_NAME_TO_HASH[digest_type])

    digest_path = str(filename) + ".digest." + digest_type
    with open(digest_path, "w") as sink:
        sink.write(fingerprint.hex().upper())


def generate_crl(issuer_cert, issuer_key, dest, cert_to_revoke=None):
    """Generate a CRL, write it to `dest` as PEM, and write its sha256/sha1 digests.

    :param issuer_cert: x509.Certificate object which issues this CRL.
    :param issuer_key: Private key object to sign the CRL with.
    :param dest: Path to output CRL to.
    :param cert_to_revoke: x509.Certificate object which this CRL should revoke. None for no revocation.
    """
    print(f"Writing CRL: {dest}")

    # The CRL covers the same window as a maximum-lifetime certificate.
    next_update = DEFAULT_START_TIME + datetime.timedelta(days=MAX_VALIDITY_PERIOD_DAYS)
    builder = x509.CertificateRevocationListBuilder()
    builder = builder.issuer_name(issuer_cert.subject)
    builder = builder.last_update(DEFAULT_START_TIME)
    builder = builder.next_update(next_update)

    if cert_to_revoke is not None:
        revoked = (
            x509.RevokedCertificateBuilder()
            .serial_number(cert_to_revoke.serial_number)
            .revocation_date(DEFAULT_START_TIME)
            .build()
        )
        builder = builder.add_revoked_certificate(revoked)

    crl_pem = builder.sign(issuer_key, hashes.SHA256()).public_bytes(serialization.Encoding.PEM)
    with open(dest, "wb") as out:
        out.write(crl_pem)

    for digest_type in ("sha256", "sha1"):
        write_digest(dest, "crl", digest_type)


def generate_all_crls():
    """Generate all required CRLs. Hardcoded with the expectation that we won't need to add new ones frequently.

    Raises:
        ValueError: if any of the certificates the CRLs are built from cannot be found.
    """
    try:
        ca, ca_key = get_cert_and_key("ca.pem")
        trusted_ca, trusted_ca_key = get_cert_and_key("trusted-ca.pem")
        client_revoked, _ = get_cert_and_key("client_revoked.pem")
        # Load the intermediate CA itself; reusing "ca.pem" here would make the
        # "intermediate_ca_B" CRLs below revoke, and be issued by, the root CA.
        intermediate_ca, intermediate_ca_key = get_cert_and_key("intermediate-ca-B.pem")
    except FileNotFoundError as e:
        raise ValueError(
            "ca.pem, trusted-ca.pem, client_revoked.pem, and intermediate-ca-B.pem are required in order to generate CRLs"
        ) from e

    generate_crl(ca, ca_key, OUTPUT_PATH / "crl.pem")
    generate_crl(ca, ca_key, OUTPUT_PATH / "crl_client_revoked.pem", client_revoked)
    generate_crl(ca, ca_key, OUTPUT_PATH / "crl_intermediate_ca_B_revoked.pem", intermediate_ca)
    generate_crl(trusted_ca, trusted_ca_key, OUTPUT_PATH / "crl_from_trusted_ca.pem")
    generate_crl(
        intermediate_ca, intermediate_ca_key, OUTPUT_PATH / "crl_from_intermediate_ca_B.pem"
    )


def parse_command_line():
    """Build the argument parser, parse the process's command line, and return the result."""
    default_config = str(PurePath("x509/certs.yml"))
    default_output = str(PurePath("."))
    default_static_dir = str(PurePath("x509/static"))

    parser = argparse.ArgumentParser(description="X509 Test Certificate Generator")
    parser.add_argument(
        "--config", help="Certificate definition file", type=str, default=default_config
    )
    parser.add_argument(
        "--mkcrl",
        action=argparse.BooleanOptionalAction,
        help="Set to generate the default list of CRLs as well",
        default=False,
    )
    parser.add_argument("-o", "--output", help="Output path", type=str, default=default_output)
    parser.add_argument(
        "--static-dir",
        help="Path to directory containing signing keys for certs",
        type=str,
        default=default_static_dir,
    )
    # Positional: zero or more certificate names; none means "generate everything".
    parser.add_argument("cert", nargs="*", help="Certificate to generate (blank for all)")

    return parser.parse_args()


def validate_config():
    """Sanity-check the loaded config before any certificate is generated.

    Raises:
        ValueError: if no certificates are defined, if a definition lacks a "name" or
            "description", or if a definition contains an unrecognized key.
    """
    if not CONFIG.get("certs"):
        raise ValueError("No certificates defined")

    allowed_keys = frozenset(
        (
            "name",
            "description",
            "Subject",
            "Issuer",
            "append_cert",
            "extensions",
            "passphrase",
            "include_header",
            "keyfile",
            "split_cert_and_key",
            "explicit_subject",
            "serial",
            "not_before",
            "not_after",
            "pkcs12",
        )
    )
    # Checked in this order, so a missing name is reported before a missing description.
    required_fields = (
        ("name", "Name field required for all certificate definitions"),
        ("description", "description field required for all certificate definitions"),
    )

    for definition in CONFIG.get("certs", []):
        present = definition.keys()
        for field, complaint in required_fields:
            if field not in present:
                raise ValueError(complaint)
        for field in present:
            if field not in allowed_keys:
                raise ValueError(
                    "Unknown element '" + field + "' in certificate: " + definition["name"]
                )


def select_items(names):
    """Select all certificates requested and their ancestor nodes.

    Args:
        names: certificate names requested on the command line; empty selects everything.

    Returns:
        The certificate definitions to process: the requested ones plus, transitively,
        every definition they reference through "Issuer" or "append_cert".

    Raises:
        ValueError: if a requested name has no definition in the config.
    """
    if not names:
        return CONFIG["certs"]

    # Temporarily treat like dictionary for easy de-duping.
    selected = {}
    # Start with the cert(s) explicitly asked for.
    for name in names:
        cert = find_certificate_definition(name)
        if not cert:
            raise ValueError("Unknown certificate: " + name)
        selected[name] = cert

    # Keep pulling in dependencies until a pass adds nothing new.
    previous_count = -1
    while previous_count != len(selected):
        previous_count = len(selected)
        required = set()
        for cert in selected.values():
            # "self", None, and external issuers match no definition and are harmless.
            required.add(cert.get("Issuer"))
            appended = cert.get("append_cert", [])
            if isinstance(appended, str):
                # A single name is shorthand for a one-element list (see process_cert).
                appended = [appended]
            required.update(appended)
        selected.update(
            {cert["name"]: cert for cert in CONFIG["certs"] if cert["name"] in required}
        )

    return list(selected.values())


def sort_items(items):
    """Ensure that leaves are produced after roots (as much as possible within one file).

    A certificate is emitted once its issuer is available (self-signed, already emitted,
    or not defined in `items` at all) and every "append_cert" entry that is defined in
    `items` has been emitted. The relative order of `items` is otherwise preserved.

    Args:
        items: iterable of certificate definitions (mappings with at least "name").

    Returns:
        list: the same definitions in a dependency-respecting order.

    Raises:
        ValueError: if no valid order exists, e.g. because of a circular dependency or
            two definitions sharing a name.
    """
    items = list(items)
    known_names = {cert["name"] for cert in items}
    processed_names = set()
    ordered = []

    while len(ordered) < len(items):
        made_progress = False
        for cert in items:
            name = cert["name"]
            if name in processed_names:
                continue

            appended = cert.get("append_cert", [])
            if isinstance(appended, str):
                # A single name is shorthand for a one-element list; without this the
                # string would be scanned character by character.
                appended = [appended]
            # only concern ourselves with prependents in this config file.
            unmet_prependents = [
                dep for dep in appended if dep in known_names and dep not in processed_names
            ]

            # Self-signed, signed by someone in ordered already, or signed externally
            issuer = cert.get("Issuer")
            has_issuer = (
                issuer == "self" or issuer in processed_names or issuer not in known_names
            )

            if has_issuer and not unmet_prependents:
                ordered.append(cert)
                processed_names.add(name)
                made_progress = True

        if not made_progress:
            # Nothing could be placed in a full pass, so further passes never will either.
            stuck = sorted({str(cert["name"]) for cert in items} - {str(n) for n in processed_names})
            raise ValueError(
                "Unable to order certificates (circular dependency or duplicate name): "
                + ", ".join(stuck or sorted(str(n) for n in known_names))
            )

    return ordered


def setup_global_state(parsed_args):
    """Populate the module-level state from parsed command-line arguments.

    Records the config path and the output/static base paths, loads the YAML config into
    CONFIG, and validates it.
    """
    global CONFIG, CONFIGFILE, OUTPUT_PATH, STATIC_PATH

    CONFIGFILE = parsed_args.config
    OUTPUT_PATH = PurePath(parsed_args.output)
    STATIC_PATH = PurePath(parsed_args.static_dir)

    # NOTE(review): yaml.FullLoader is fine for a trusted, checked-in config; it should
    # not be pointed at untrusted input.
    with open(CONFIGFILE) as config_stream:
        CONFIG = yaml.load(config_stream, Loader=yaml.FullLoader)

    validate_config()


def main():
    """Script entry point: generate the requested certificates, digests, and optional CRLs."""
    args = parse_command_line()
    setup_global_state(args)

    requested = args.cert or []
    for item in sort_items(select_items(requested)):
        process_cert(item)
        filename = make_filename(item)
        for digest_type in ("sha256", "sha1"):
            write_digest(filename, "cert", digest_type)

    if args.mkcrl:
        generate_all_crls()


if __name__ == "__main__":
    main()
